# Fine-tune Whisper on EA WRC co-driver audio clips

from dataclasses import dataclass
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from loguru import logger
from typing import List, Any, Dict, Union
import random
import torch
import evaluate
from datasets import Audio, Dataset
from pydub import AudioSegment
import json
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration, WhisperProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer


DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
BASE_MODEL = 'openai/whisper-small'
feature_extractor = WhisperFeatureExtractor.from_pretrained(BASE_MODEL)
tokenizer = WhisperTokenizer.from_pretrained(BASE_MODEL, language="english", task="transcribe")
processor = WhisperProcessor.from_pretrained(BASE_MODEL, language="english", task="transcribe")
EVALUATE_PATH = '/root/github/evaluate'
metric = evaluate.load(EVALUATE_PATH + '/metrics/wer/wer.py')

ROOT_PATH = Path('/root/ea sounds')
SOUND_PATH = ROOT_PATH / 'raw'
# OUTPUT_PATH = ROOT_PATH / 'output'

def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)


def get_all_available_files():
    # rglob already yields Path objects; skip files whose names start with 'cd'
    result = [f for f in tqdm(SOUND_PATH.rglob('*.wav')) if not f.name.startswith('cd')]

    logger.info(f'got {len(result)} files')
    return result

def get_sound_label(sound_file: str):
    # derive the transcript from the file name: strip the extension,
    # split on '_', and drop a trailing duplicate counter if present
    parts = Path(sound_file).stem.lower().split('_')
    if len(parts) > 1 and parts[-1].isnumeric():
        # redundant variant of another clip, drop the counter
        parts = parts[:-1]

    return ' '.join(parts)
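
# Illustrative mapping (hypothetical file names):
#   'over_crest.wav'   -> 'over crest'
#   'over_crest_2.wav' -> 'over crest'  (trailing counter dropped)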

def preprocess_data(files: List[str], data_count=30000, output_path=ROOT_PATH / 'out', output_json='data.json', is_training=True):
    # build synthetic samples by concatenating 2-3 random clips with a short
    # silence in between, mimicking consecutive co-driver calls
    output_path.mkdir(parents=True, exist_ok=True)
    data = []

    for d in tqdm(range(data_count)):
        obj = {}

        # randomly select 2-3 files and concatenate them
        files_count = random.randint(2, 3)
        selected_files = random.sample(files, files_count)
        sounds = AudioSegment.from_wav(selected_files[0])
        for i in range(1, len(selected_files)):
            sound = AudioSegment.from_wav(selected_files[i])
            empty = AudioSegment.silent(random.randint(500, 1500))  # 0.5-1.5 s of silence
            sounds = sounds + empty + sound

        # export the concatenated audio
        output_filename = output_path / f'{d}.wav'
        sounds.export(output_filename, format='wav')

        # the label is the source clips' labels joined in order
        labels = [get_sound_label(Path(f).name) for f in selected_files]
        label = ' '.join(labels)

        obj['file'] = str(output_filename)
        obj['text'] = label
        obj['origin_files'] = [str(f) for f in selected_files]

        data.append(obj)

    if is_training:
        # for training, also include every original clip as a single-call sample
        data.extend([
            {'file': str(f), 'text': get_sound_label(f.name), 'origin_files': [str(f)]}
            for f in files
            ])

    # shuffle
    random.shuffle(data)

    # save the descriptor
    with open(ROOT_PATH / output_json, 'w', encoding='utf8') as f:
        json.dump(data, f, indent=4)
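
# Each descriptor entry looks like (illustrative paths and text):
# {
#     "file": "/root/ea sounds/train/0.wav",
#     "text": "over crest into left",
#     "origin_files": ["/root/ea sounds/raw/over_crest.wav", "/root/ea sounds/raw/into_left.wav"]
# }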

def get_dataset(descriptor_path: str):
    # load the JSON descriptor written by preprocess_data
    with open(descriptor_path, 'r', encoding='utf8') as f:
        data = json.load(f)

    to_be_converted = {'audio': [f['file'] for f in data], 'label': [f['text'] for f in data]}

    audio_ds = Dataset.from_dict(to_be_converted).cast_column('audio', Audio(sampling_rate=16000))
    return audio_ds
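
# The resulting Dataset has two columns: 'audio' (lazily decoded and
# resampled to 16 kHz on access) and 'label' (the transcript string).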

def prepare_dataset(batch):
    # the Audio cast decodes the file and resamples it to 16 kHz on access
    audio = batch["audio"]

    # compute log-Mel input features from the raw audio array
    # (the device kwarg requires a recent transformers release)
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"], device=DEVICE).input_features[0]

    # encode the target text to label ids
    batch["labels"] = tokenizer(batch["label"]).input_ids
    return batch
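
# For whisper-small, input_features is an (80, 3000) log-Mel spectrogram:
# 80 mel bins over 3000 frames, i.e. audio padded/truncated to 30 seconds.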

def load_model():
    model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL)
    # pin language/task and clear the legacy forced decoder ids so that
    # evaluation-time generation follows the generation_config
    model.generation_config.language = 'english'
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None

    return model
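
# A minimal single-clip inference sketch (illustrative; not called by the
# training loop, and it assumes 16-bit PCM wavs that fit Whisper's 30 s window):
def transcribe_file(model, wav_path: str) -> str:
    import numpy as np

    # decode with pydub and force 16 kHz mono to match the feature extractor
    sound = AudioSegment.from_wav(wav_path).set_frame_rate(16000).set_channels(1)
    # int16 samples -> float32 in [-1, 1]
    samples = np.array(sound.get_array_of_samples()).astype(np.float32) / 32768.0

    inputs = processor(samples, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        ids = model.generate(inputs.input_features.to(model.device))
    return processor.batch_decode(ids, skip_special_tokens=True)[0]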

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if a bos token was prepended in the tokenization step above,
        # cut it here as it is appended again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
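
# Illustrative collator output for a batch of two prepared examples:
#   batch["input_features"] -> FloatTensor of shape (2, 80, 3000)
#   batch["labels"]         -> LongTensor of shape (2, max_label_len), padded with -100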
    
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

def train():
    model = load_model()
    ds_train = get_dataset(ROOT_PATH / 'data_train.json')
    ds_train = ds_train.map(prepare_dataset)
    ds_test = get_dataset(ROOT_PATH / 'data_test.json')
    ds_test = ds_test.map(prepare_dataset)
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir="./whisper-small-hi",  # change to a repo name of your choice
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
        learning_rate=1e-5,
        warmup_steps=500,
        max_steps=2000,
        gradient_checkpointing=True,
        fp16=True,
        evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=100,
        eval_steps=100,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        # push_to_hub=True,
    )
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )

    trainer.train() # let's go
    trainer.save_model(ROOT_PATH / 'checkpoint')

    print(trainer.state.best_model_checkpoint)

if __name__ == "__main__":
    set_seed(42)
    files = get_all_available_files()
    # NOTE: the train and test mixes are sampled from the same pool of raw clips,
    # so the eval WER mainly measures fit to this vocabulary, not generalization
    preprocess_data(files, data_count=500, output_path=ROOT_PATH / 'train', output_json='data_train.json')
    preprocess_data(files, data_count=1500, output_path=ROOT_PATH / 'test', output_json='data_test.json', is_training=False)
    train()


